/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.parquet.filter2;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;

/**
 * Build-time generator for {@code org.apache.parquet.filter2.recordlevel.IncrementallyUpdatedFilterPredicateBuilder}.
 *
 * <p>The generated class emits one {@code visit} method per filter operator (Eq, NotEq, In, NotIn, Contains,
 * Lt, LtEq, Gt, GtEq, UserDefined, LogicalNotUserDefined). Each method contains one type-specialised
 * {@code ValueInspector} per entry in {@link #TYPES}, so that record-level filtering avoids boxing.
 */
public class IncrementallyUpdatedFilterPredicateGenerator {

  /**
   * Writes {@code IncrementallyUpdatedFilterPredicateBuilder.java} below the given source root.
   *
   * @param args {@code args[0]} is the generated-sources root; the file is written to
   *     {@code <root>/org/apache/parquet/filter2/recordlevel/IncrementallyUpdatedFilterPredicateBuilder.java}
   * @throws IllegalArgumentException if no source root is supplied
   * @throws IOException if the target directory cannot be created or the file cannot be written
   */
  public static void main(String[] args) throws IOException {
    if (args == null || args.length < 1) {
      throw new IllegalArgumentException(
          "Usage: IncrementallyUpdatedFilterPredicateGenerator <generated-sources-root>");
    }
    File srcFile = new File(
        args[0] + "/org/apache/parquet/filter2/recordlevel/IncrementallyUpdatedFilterPredicateBuilder.java");
    srcFile = srcFile.getAbsoluteFile();
    File parent = srcFile.getParentFile();
    if (!parent.exists() && !parent.mkdirs()) {
      throw new IOException("Couldn't mkdirs for " + parent);
    }
    new IncrementallyUpdatedFilterPredicateGenerator(srcFile).run();
  }

  /** Destination of the generated source; closed by {@link #run()}. */
  private final Writer writer;

  /**
   * Opens {@code file} for writing as UTF-8, creating it or truncating any existing content.
   *
   * @param file the Java source file to generate
   * @throws IOException if the file cannot be opened
   */
  public IncrementallyUpdatedFilterPredicateGenerator(File file) throws IOException {
    this.writer = Files.newBufferedWriter(file.toPath(), StandardCharsets.UTF_8);
  }

  /** Describes one column type for which specialised inspectors are generated. */
  private static class TypeInfo {
    /** Class name used in {@code clazz.equals(X.class)} checks, casts and comparator generics. */
    public final String className;
    /** Parameter type of the generated {@code ValueInspector.update(...)} overload. */
    public final String primitiveName;
    /** Whether Lt/LtEq/Gt/GtEq get real comparisons; if false they generate an {@code IllegalArgumentException}. */
    public final boolean supportsInequality;

    private TypeInfo(String className, String primitiveName, boolean supportsInequality) {
      this.className = className;
      this.primitiveName = primitiveName;
      this.supportsInequality = supportsInequality;
    }
  }

  private static final TypeInfo[] TYPES = new TypeInfo[] {
    new TypeInfo("Integer", "int", true),
    new TypeInfo("Long", "long", true),
    new TypeInfo("Boolean", "boolean", false),
    new TypeInfo("Float", "float", true),
    new TypeInfo("Double", "double", true),
    new TypeInfo("Binary", "Binary", true),
  };

  /**
   * Generates the complete builder source and closes the output, even if generation fails.
   *
   * @throws IOException if writing fails
   */
  public void run() throws IOException {
    try {
      generate();
    } finally {
      writer.close();
    }
  }

  /** Emits the whole builder class: header, constructor, and every visit method. */
  private void generate() throws IOException {
    add("package org.apache.parquet.filter2.recordlevel;\n" + "\n"
        + "import java.util.Iterator;\n"
        + "import java.util.List;\n"
        + "import java.util.Set;\n"
        + "\n"
        + "import org.apache.parquet.hadoop.metadata.ColumnPath;\n"
        + "import org.apache.parquet.filter2.predicate.FilterPredicate;\n"
        + "import org.apache.parquet.filter2.predicate.Operators;\n"
        + "import org.apache.parquet.filter2.predicate.Operators.Contains;\n"
        + "import org.apache.parquet.filter2.predicate.Operators.Eq;\n"
        + "import org.apache.parquet.filter2.predicate.Operators.Gt;\n"
        + "import org.apache.parquet.filter2.predicate.Operators.GtEq;\n"
        + "import org.apache.parquet.filter2.predicate.Operators.In;\n"
        + "import org.apache.parquet.filter2.predicate.Operators.LogicalNotUserDefined;\n"
        + "import org.apache.parquet.filter2.predicate.Operators.Lt;\n"
        + "import org.apache.parquet.filter2.predicate.Operators.LtEq;\n"
        + "import org.apache.parquet.filter2.predicate.Operators.NotEq;\n"
        + "import org.apache.parquet.filter2.predicate.Operators.NotIn;\n"
        + "import org.apache.parquet.filter2.predicate.Operators.UserDefined;\n"
        + "import org.apache.parquet.filter2.predicate.UserDefinedPredicate;\n"
        + "import org.apache.parquet.filter2.recordlevel.IncrementallyUpdatedFilterPredicate.ValueInspector;\n"
        + "import org.apache.parquet.filter2.recordlevel.IncrementallyUpdatedFilterPredicate.DelegatingValueInspector;\n"
        + "import org.apache.parquet.io.api.Binary;\n"
        + "import org.apache.parquet.io.PrimitiveColumnIO;\n"
        + "import org.apache.parquet.schema.PrimitiveComparator;\n\n"
        + "/**\n"
        + " * This class is auto-generated by org.apache.parquet.filter2.IncrementallyUpdatedFilterPredicateGenerator\n"
        + " * Do not manually edit!\n"
        + " * See {@link IncrementallyUpdatedFilterPredicateBuilderBase}\n"
        + " */\n");

    add(
        "public class IncrementallyUpdatedFilterPredicateBuilder extends IncrementallyUpdatedFilterPredicateBuilderBase {\n\n");

    add("  public IncrementallyUpdatedFilterPredicateBuilder(List<PrimitiveColumnIO> leaves) {\n"
        + "    super(leaves);\n"
        + "  }\n\n");

    addVisitBegin("Eq");
    for (TypeInfo info : TYPES) {
      addEqNotEqCase(info, true, false);
    }
    addVisitEnd();

    addVisitBegin("NotEq");
    for (TypeInfo info : TYPES) {
      addEqNotEqCase(info, false, false);
    }
    addVisitEnd();

    addVisitBegin("In");
    for (TypeInfo info : TYPES) {
      addInNotInCase(info, true, false);
    }
    addVisitEnd();

    addVisitBegin("NotIn");
    for (TypeInfo info : TYPES) {
      addInNotInCase(info, false, false);
    }
    addVisitEnd();

    // The Contains helper classes are emitted as members before visit(Contains) itself.
    addContainsBegin();
    addVisitBegin("Contains");
    addContainsCase();
    addContainsEnd();
    addVisitEnd();

    addVisitBegin("Lt");
    for (TypeInfo info : TYPES) {
      addInequalityCase(info, "<", false);
    }
    addVisitEnd();

    addVisitBegin("LtEq");
    for (TypeInfo info : TYPES) {
      addInequalityCase(info, "<=", false);
    }
    addVisitEnd();

    addVisitBegin("Gt");
    for (TypeInfo info : TYPES) {
      addInequalityCase(info, ">", false);
    }
    addVisitEnd();

    addVisitBegin("GtEq");
    for (TypeInfo info : TYPES) {
      addInequalityCase(info, ">=", false);
    }
    addVisitEnd();

    add(
        "  @Override\n"
            + "  public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> IncrementallyUpdatedFilterPredicate visit(UserDefined<T, U> pred) {\n");
    addUdpBegin();
    for (TypeInfo info : TYPES) {
      addUdpCase(info, false);
    }
    addVisitEnd();

    add("  @Override\n"
        + "  public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> IncrementallyUpdatedFilterPredicate visit(LogicalNotUserDefined<T, U> notPred) {\n"
        + "    UserDefined<T, U> pred = notPred.getUserDefined();\n");
    addUdpBegin();
    for (TypeInfo info : TYPES) {
      addUdpCase(info, true);
    }
    addVisitEnd();

    add("}\n");
  }

  /** Emits the header of a {@code visit(<inVar><T> pred)} method and its shared locals. */
  private void addVisitBegin(String inVar) throws IOException {
    add("  @Override\n" + "  public <T extends Comparable<T>> IncrementallyUpdatedFilterPredicate visit("
        + inVar + "<T> pred) {\n" + "    ColumnPath columnPath = pred.getColumn().getColumnPath();\n"
        + "    Class<T> clazz = pred.getColumn().getColumnType();\n"
        + "\n"
        + "    ValueInspector valueInspector = null;\n\n");
  }

  /** Emits the unknown-type check, registration and return that close every generated visit method. */
  private void addVisitEnd() throws IOException {
    add("    if (valueInspector == null) {\n"
        + "      throw new IllegalArgumentException(\"Encountered unknown type \" + clazz);\n"
        + "    }\n"
        + "\n"
        + "    addValueInspector(columnPath, valueInspector);\n"
        + "    return valueInspector;\n"
        + "  }\n\n");
  }

  /**
   * Emits the Eq (isEq) or NotEq inspector for one column type.
   *
   * @param expectMultipleResults true when generating for a repeated column inside Contains. Then the
   *     inspector only ever sets {@code true} and skips once its result is known.
   */
  private void addEqNotEqCase(TypeInfo info, boolean isEq, boolean expectMultipleResults) throws IOException {
    add("    if (clazz.equals(" + info.className + ".class)) {\n");

    // Predicates for repeated fields don't need to support null values
    if (!expectMultipleResults) {
      add("      if (pred.getValue() == null) {\n"
          + "        valueInspector = new ValueInspector() {\n"
          + "          @Override\n"
          + "          public void updateNull() {\n"
          + "            setResult("
          + isEq + ");\n" + "          }\n"
          + "\n"
          + "          @Override\n"
          + "          public void update("
          + info.primitiveName + " value) {\n" + "            setResult("
          + !isEq + ");\n" + "          }\n"
          + "        };\n"
          + "      } else {\n");
    }

    add("        final "
        + info.primitiveName + " target = (" + info.className + ") (Object) pred.getValue();\n"
        + "        final PrimitiveComparator<"
        + info.className + "> comparator = getComparator(columnPath);\n" + "\n"
        + "        valueInspector = new ValueInspector() {\n"
        + "          @Override\n"
        + "          public void updateNull() {\n"
        + "            setResult("
        + !isEq + ");\n" + "          }\n"
        + "\n"
        + "          @Override\n"
        + "          public void update("
        + info.primitiveName + " value) {\n");

    if (!expectMultipleResults) {
      add("            setResult(" + compareEquality("value", "target", isEq) + ");\n");
    } else {
      add("            if (!isKnown() && " + compareEquality("value", "target", isEq)
          + ") { setResult(true); }\n");
    }

    add("          }\n        };\n");

    if (!expectMultipleResults) {
      add("      }\n");
    }

    add("    }\n\n");
  }

  /**
   * Emits the inspector for an inequality operator ({@code op} is one of {@code <, <=, >, >=}).
   *
   * <p>Types without inequality support get a generated {@code IllegalArgumentException} instead.
   */
  private void addInequalityCase(TypeInfo info, String op, boolean expectMultipleResults) throws IOException {
    if (!info.supportsInequality) {
      add("    if (clazz.equals(" + info.className + ".class)) {\n");
      add("      throw new IllegalArgumentException(\"Operator " + op + " not supported for " + info.className
          + "\");\n");
      add("    }\n\n");
      return;
    }

    add("    if (clazz.equals(" + info.className + ".class)) {\n" + "      final "
        + info.primitiveName + " target = (" + info.className + ") (Object) pred.getValue();\n"
        + "      final PrimitiveComparator<"
        + info.className + "> comparator = getComparator(columnPath);\n" + "\n"
        + "      valueInspector = new ValueInspector() {\n"
        + "        @Override\n"
        + "        public void updateNull() {\n"
        + "          setResult(false);\n"
        + "        }\n"
        + "\n"
        + "        @Override\n"
        + "        public void update("
        + info.primitiveName + " value) {\n");

    if (!expectMultipleResults) {
      add("          setResult(comparator.compare(value, target) " + op + " 0);\n");
    } else {
      add("          if (!isKnown() && comparator.compare(value, target) " + op + " 0)"
          + " { setResult(true); }\n");
    }

    add("        }\n" + "      };\n" + "    }\n\n");
  }

  /**
   * Emits the In (isEq) or NotIn inspector for one column type.
   *
   * <p>A value matches when it compares equal to one of the non-null set elements. A null value matches
   * In exactly when the set contains null. NotIn is the negation of In in both cases. In repeated
   * (Contains) mode only the "predicate holds" outcome is recorded via {@code setResult(true)}; otherwise
   * the inspector stays unknown.
   */
  private void addInNotInCase(TypeInfo info, boolean isEq, boolean expectMultipleResults) throws IOException {
    add("    if (clazz.equals(" + info.className + ".class)) {\n"
        + "      final Set<" + info.className + "> target = (Set<" + info.className + ">) pred.getValues();\n"
        + "      final boolean targetContainsNull = target.contains(null);\n"
        + "      final PrimitiveComparator<" + info.className + "> comparator = getComparator(columnPath);\n"
        + "\n"
        + "      valueInspector = new ValueInspector() {\n"
        + "        @Override\n"
        + "        public void updateNull() {\n"
        + "          setResult(" + (isEq ? "" : "!") + "targetContainsNull);\n"
        + "        }\n"
        + "\n"
        + "        @Override\n"
        + "        public void update(" + info.primitiveName + " value) {\n");

    if (expectMultipleResults) {
      // setResult may only be called once per inspector, so stop as soon as the result is known.
      add("          if (isKnown()) {\n"
          + "            return;\n"
          + "          }\n");
    }

    // Always test membership with equality; the operator only decides which outcome is reported.
    add("          for (" + info.className + " element : target) {\n"
        + "            if (element == null) {\n"
        + "              continue;\n"
        + "            }\n"
        + "            final " + info.primitiveName + " candidate = element;\n"
        + "            if (" + compareEquality("value", "candidate", true) + ") {\n");
    if (!expectMultipleResults || isEq) {
      add("              setResult(" + isEq + ");\n");
    }
    add("              return;\n"
        + "            }\n"
        + "          }\n");
    if (!expectMultipleResults || !isEq) {
      add("          setResult(" + !isEq + ");\n");
    }
    add("        }\n"
        + "      };\n"
        + "    }\n\n");
  }

  /** Emits the shared locals for the UserDefined / LogicalNotUserDefined visit methods. */
  private void addUdpBegin() throws IOException {
    add("    ColumnPath columnPath = pred.getColumn().getColumnPath();\n"
        + "    Class<T> clazz = pred.getColumn().getColumnType();\n"
        + "\n"
        + "    ValueInspector valueInspector = null;\n"
        + "\n"
        + "    final U udp = pred.getUserDefinedPredicate();\n"
        + "\n");
  }

  /**
   * Emits a {@code ContainsInspectorVisitor.visit(<op><T>)} method that builds repeated-mode inspectors
   * and wraps the result in a {@code ContainsSinglePredicate}.
   *
   * @throws UnsupportedOperationException if {@code op} has no repeated-mode implementation
   */
  private void addContainsInspectorVisitor(String op) throws IOException {
    add("    @Override\n"
        + "    public <T extends Comparable<T>> ContainsPredicate visit(" + op + "<T> pred) {\n"
        + "      ColumnPath columnPath = pred.getColumn().getColumnPath();\n"
        + "      Class<T> clazz = pred.getColumn().getColumnType();\n"
        + "      ValueInspector valueInspector = null;\n");

    for (TypeInfo info : TYPES) {
      switch (op) {
        case "Eq":
          addEqNotEqCase(info, true, true);
          break;
        case "NotEq":
          addEqNotEqCase(info, false, true);
          break;
        case "Lt":
          addInequalityCase(info, "<", true);
          break;
        case "LtEq":
          addInequalityCase(info, "<=", true);
          break;
        case "Gt":
          addInequalityCase(info, ">", true);
          break;
        case "GtEq":
          addInequalityCase(info, ">=", true);
          break;
        case "In":
          addInNotInCase(info, true, true);
          break;
        case "NotIn":
          addInNotInCase(info, false, true);
          break;
        default:
          throw new UnsupportedOperationException("Op " + op + " not implemented for Contains filter");
      }
    }

    add("      return new ContainsSinglePredicate(valueInspector, false);\n" + "    }\n");
  }

  /**
   * Emits the Contains support classes: the abstract {@code ContainsPredicate}, its single/And/Or
   * implementations, and the {@code ContainsInspectorVisitor} that maps a Contains tree onto them.
   */
  private void addContainsBegin() throws IOException {
    add("  private abstract static class ContainsPredicate extends DelegatingValueInspector {\n"
        + "    ContainsPredicate(ValueInspector... delegates) {\n"
        + "      super(delegates);\n"
        + "    }\n"
        + "\n"
        + "    abstract ContainsPredicate not();\n"
        + "  }\n"
        + "\n"
        + "  private static class ContainsSinglePredicate extends ContainsPredicate {\n"
        + "    private final boolean isNot;\n"
        + " \n"
        + "    private ContainsSinglePredicate(ValueInspector inspector, boolean isNot) {\n"
        + "      super(inspector);\n"
        + "      this.isNot = isNot;\n"
        + "    }\n\n"
        + "    @Override\n"
        + "    ContainsPredicate not() {\n"
        + "      return new ContainsSinglePredicate(getDelegates().iterator().next(), true);\n"
        + "    }\n"
        + "\n"
        + "    @Override\n"
        + "    void onUpdate() {\n"
        + "      if (isKnown()) {\n"
        + "        return;\n"
        + "      }\n"
        + "\n"
        + "      for (ValueInspector inspector : getDelegates()) {\n"
        + "        if (inspector.isKnown() && inspector.getResult()) {\n"
        + "          setResult(!isNot);\n"
        + "          return;\n"
        + "        }\n"
        + "      }\n"
        + "    }\n"
        + "\n"
        + "    @Override\n"
        + "    void onNull() {\n"
        + "      setResult(isNot);\n"
        + "    }\n"
        + "  }\n\n");

    // not(A and B) == not(A) or not(B)  (De Morgan)
    add("  private static class ContainsAndPredicate extends ContainsPredicate {\n"
        + "    private ContainsAndPredicate(ContainsPredicate left, ContainsPredicate right) {\n"
        + "      super(left, right);\n"
        + "    }\n"
        + "\n"
        + "    @Override\n"
        + "    void onUpdate() {\n"
        + "      if (isKnown()) { return; }\n"
        + "\n"
        + "      boolean allKnown = true;\n"
        + "      for (ValueInspector delegate : getDelegates()) {\n"
        + "        if (delegate.isKnown() && !delegate.getResult()) {\n"
        + "          setResult(false);\n"
        + "          return;\n"
        + "        }\n"
        + "        allKnown = allKnown && delegate.isKnown();\n"
        + "      }\n"
        + "      \n"
        + "      if (allKnown) {\n"
        + "        setResult(true);\n"
        + "      }\n"
        + "    }\n"
        + "\n"
        + "    @Override\n"
        + "    void onNull() {\n"
        + "      for (ValueInspector delegate : getDelegates()) {\n"
        + "        if (!delegate.getResult()) {\n"
        + "          setResult(false);\n"
        + "          return;\n"
        + "        }\n"
        + "      }\n"
        + "      setResult(true);\n"
        + "    }\n"
        + "\n"
        + "    @Override\n"
        + "    ContainsPredicate not() {\n"
        + "      Iterator<ValueInspector> it = getDelegates().iterator();\n"
        + "      return new ContainsOrPredicate(((ContainsPredicate) it.next()).not(), ((ContainsPredicate) it.next()).not());\n"
        + "    }\n"
        + "  }\n\n");

    // onUpdate returns after its first setResult(true): a second setResult on the same inspector is an
    // error. not(A or B) == not(A) and not(B)  (De Morgan)
    add("  private static class ContainsOrPredicate extends ContainsPredicate {\n"
        + "    private ContainsOrPredicate(ContainsPredicate left, ContainsPredicate right) {\n"
        + "      super(left, right);\n"
        + "    }\n"
        + "\n"
        + "    @Override\n"
        + "    void onUpdate() {\n"
        + "      if (isKnown()) { return; }\n"
        + "\n"
        + "      for (ValueInspector delegate : getDelegates()) {\n"
        + "        if (delegate.isKnown() && delegate.getResult()) {\n"
        + "          setResult(true);\n"
        + "          return;\n"
        + "        }\n"
        + "      }\n"
        + "    }\n"
        + "\n"
        + "    @Override\n"
        + "    void onNull() {\n"
        + "      for (ValueInspector delegate : getDelegates()) {\n"
        + "        if (delegate.getResult()) {\n"
        + "          setResult(true);\n"
        + "          return;\n"
        + "        }\n"
        + "      }\n"
        + "      setResult(false);\n"
        + "    }\n"
        + "\n"
        + "    @Override\n"
        + "    ContainsPredicate not() {\n"
        + "      Iterator<ValueInspector> it = getDelegates().iterator();\n"
        + "      return new ContainsAndPredicate(((ContainsPredicate) it.next()).not(), ((ContainsPredicate) it.next()).not());\n"
        + "    }\n"
        + "  }\n\n");

    add("  private class ContainsInspectorVisitor implements FilterPredicate.Visitor<ContainsPredicate> {\n\n"
        + "    @Override\n"
        + "    public <T extends Comparable<T>> ContainsPredicate visit(Contains<T> contains) {\n"
        + "      return contains.filter(this, ContainsAndPredicate::new, ContainsOrPredicate::new, ContainsPredicate::not);\n"
        + "    }\n");

    addContainsInspectorVisitor("Eq");
    addContainsInspectorVisitor("NotEq");
    addContainsInspectorVisitor("Lt");
    addContainsInspectorVisitor("LtEq");
    addContainsInspectorVisitor("Gt");
    addContainsInspectorVisitor("GtEq");
    addContainsInspectorVisitor("In");
    addContainsInspectorVisitor("NotIn");

    add("    @Override\n"
        + "    public ContainsPredicate visit(Operators.And pred) {\n"
        + "      throw new UnsupportedOperationException(\"Operators.And not supported for Contains predicate\");\n"
        + "    }\n"
        + "\n"
        + "    @Override\n"
        + "    public ContainsPredicate visit(Operators.Or pred) {\n"
        + "      throw new UnsupportedOperationException(\"Operators.Or not supported for Contains predicate\");\n"
        + "    }\n"
        + "\n"
        + "    @Override\n"
        + "    public ContainsPredicate visit(Operators.Not pred) {\n"
        + "      throw new UnsupportedOperationException(\"Operators.Not not supported for Contains predicate\");\n"
        + "    }\n"
        + "\n"
        + "    @Override\n"
        + "    public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> ContainsPredicate visit(\n"
        + "        UserDefined<T, U> pred) {\n"
        + "      throw new UnsupportedOperationException(\"UserDefinedPredicate not supported for Contains predicate\");\n"
        + "    }\n"
        + "\n"
        + "    @Override\n"
        + "    public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> ContainsPredicate visit(\n"
        + "        LogicalNotUserDefined<T, U> pred) {\n"
        + "      throw new UnsupportedOperationException(\"LogicalNotUserDefined not supported for Contains predicate\");\n"
        + "    }\n"
        + "  }\n"
        + "\n");
  }

  /** Emits the body of {@code visit(Contains)}: delegate to a fresh {@code ContainsInspectorVisitor}. */
  private void addContainsCase() throws IOException {
    add("    valueInspector = new ContainsInspectorVisitor().visit(pred);\n");
  }

  /** Intentionally empty; kept for symmetry with {@link #addContainsBegin()}. */
  private void addContainsEnd() {
    // No-op
  }

  /**
   * Emits a user-defined-predicate inspector for one type, negating {@code udp} results when
   * {@code invert} is true (LogicalNotUserDefined).
   */
  private void addUdpCase(TypeInfo info, boolean invert) throws IOException {
    add("    if (clazz.equals(" + info.className + ".class)) {\n"
        + "      valueInspector = new ValueInspector() {\n"
        + "        @Override\n"
        + "        public void updateNull() {\n"
        + "          setResult("
        + (invert ? "!" : "") + "udp.acceptsNullValue());\n" + "        }\n"
        + "\n"
        + "        @SuppressWarnings(\"unchecked\")\n"
        + "        @Override\n"
        + "        public void update("
        + info.primitiveName + " value) {\n" + "          setResult("
        + (invert ? "!" : "") + "udp.keep((T) (Object) value));\n" + "        }\n"
        + "      };\n"
        + "    }\n\n");
  }

  /** Returns the source text {@code comparator.compare(var, target) == 0} (or {@code != 0} when !eq). */
  private String compareEquality(String var, String target, boolean eq) {
    return "comparator.compare(" + var + ", " + target + ")" + (eq ? " == 0" : " != 0");
  }

  /** Appends raw text to the generated file. */
  private void add(String s) throws IOException {
    writer.write(s);
  }
}
